# coding=utf-8
import datetime
import io
import uuid

import falcon
from playhouse.shortcuts import model_to_dict

from bgtask import submit
from db.base import db
from db.models import Task, TaskStep, TaskExecRecord, StepExecRecord, ExecStatus, Action, SSHServer
from libs.ssh import SSHClient


class TaskResource:
    """Falcon resource for task CRUD and for launching task runs.

    A task is an ordered list of steps; each step runs an ``Action`` on an
    ``SSHServer``. Runs are executed in the background via ``bgtask.submit``
    and tracked with ``TaskExecRecord`` (one per run) and ``StepExecRecord``
    (one per executed step).
    """

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _get_task(resource_id):
        """Load a task by id.

        Raises:
            falcon.HTTPBadRequest: if no task with ``resource_id`` exists.
        """
        try:
            return Task.get_by_id(resource_id)
        except Task.DoesNotExist:
            raise falcon.HTTPBadRequest(title="Resource not exist", description=str(resource_id)) from None

    @staticmethod
    def _is_running(task_id):
        """Return True if the task has an execution record in RUNNING state."""
        # ``exists`` must be *called*: the bare bound method is always truthy.
        return TaskExecRecord.select().where(TaskExecRecord.task_id == task_id,
                                             TaskExecRecord.status == ExecStatus.RUNNING).exists()

    @staticmethod
    def _create_steps(task_id, steps):
        """Insert one ``TaskStep`` per item of ``steps``, numbering ``order`` from 0.

        Each item is read with ``.get("action_id")`` / ``.get("ssh_server_id")``
        (presumably dicts from the JSON body). ``None`` is treated as no steps.
        """
        for order, item in enumerate(steps or []):
            task_step = TaskStep()
            task_step.task_id = task_id
            task_step.action_id = item.get("action_id")
            task_step.ssh_server_id = item.get("ssh_server_id")
            task_step.order = order
            task_step.save()

    # ---------------------------------------------------------------- endpoints

    def on_get(self, req: falcon.Request, resp: falcon.Response, resource_id: int = 0):
        """List tasks (``resource_id == 0``) or return a single task as a dict.

        List mode reads the ``page`` / ``page_size`` query params (defaults 1
        and 10, both must be >= 1) and adds ``last_exec_status`` to each task:
        the status of its most recently created ``TaskExecRecord``, or ``None``
        if the task has never been executed.

        Raises:
            falcon.HTTPBadRequest: if a single task is requested and does not exist.
        """
        if resource_id != 0:
            resp.media = model_to_dict(self._get_task(resource_id))
            return

        page = req.get_param_as_int("page", default=1, min_value=1)
        page_size = req.get_param_as_int("page_size", default=10, min_value=1)
        total = Task.select().count()

        data = []
        tasks = Task.select().order_by(-Task.created_time).paginate(page, page_size)
        for row in tasks.dicts():
            last_exec_record = (TaskExecRecord.select()
                                .where(TaskExecRecord.task_id == row["id"])
                                .order_by(-TaskExecRecord.created_time)
                                .first())
            # ``first()`` returns None when the task has never been executed.
            row["last_exec_status"] = last_exec_record.status if last_exec_record is not None else None
            data.append(row)

        resp.media = {
            "total": total,
            "page": page,
            "page_size": page_size,
            "data": data
        }

    def on_get_exec_record(self, req: falcon.Request, resp: falcon.Response, resource_id: int):
        """Not implemented yet: leaves the response untouched."""
        # TODO: return the execution records of task ``resource_id``.

    def on_post(self, req: falcon.Request, resp: falcon.Response):
        """Create a task and its ordered steps from the JSON body, atomically.

        Body fields read: ``name``, ``desc`` and ``steps`` (list of
        ``{"action_id", "ssh_server_id"}``; list order becomes step ``order``).
        """
        data = req.get_media()
        with db.atomic():
            new_resource = Task()
            new_resource.name = data.get("name")
            new_resource.desc = data.get("desc")
            new_resource.save()

            self._create_steps(new_resource.id, data.get("steps"))

    def on_patch(self, req: falcon.Request, resp: falcon.Response, resource_id: int):
        """Replace a task's name, description and full step list.

        All existing steps and all execution history of the task are deleted
        and the steps are recreated from the body.

        Raises:
            falcon.HTTPBadRequest: if the task does not exist or is running.
        """
        old_resource = self._get_task(resource_id)

        if self._is_running(resource_id):
            raise falcon.HTTPBadRequest(title="Task is running, can not update")

        data = req.get_media()
        with db.atomic():
            # Delete dependent rows before the rows they reference.
            StepExecRecord.delete().where(StepExecRecord.task_id == resource_id).execute()
            TaskExecRecord.delete().where(TaskExecRecord.task_id == resource_id).execute()
            TaskStep.delete().where(TaskStep.task_id == resource_id).execute()

            old_resource.name = data.get("name")
            old_resource.desc = data.get("desc")
            old_resource.save()

            self._create_steps(old_resource.id, data.get("steps"))

    def on_delete(self, req: falcon.Request, resp: falcon.Response, resource_id: int):
        """Delete a task together with its steps and execution history.

        Raises:
            falcon.HTTPBadRequest: if the task is running.
        """
        if self._is_running(resource_id):
            raise falcon.HTTPBadRequest(title="Task is running, can not delete")

        with db.atomic():
            # Children first, so any foreign keys onto these rows stay satisfied.
            StepExecRecord.delete().where(StepExecRecord.task_id == resource_id).execute()
            TaskExecRecord.delete().where(TaskExecRecord.task_id == resource_id).execute()
            TaskStep.delete().where(TaskStep.task_id == resource_id).execute()
            Task.delete_by_id(resource_id)

    def on_get_execute(self, req: falcon.Request, resp: falcon.Response, resource_id: int):
        """Create an execution record for the task and queue ``execute`` in the background.

        Raises:
            falcon.HTTPInternalServerError: if the background queue rejects the job.
        """
        # Commit the record *before* submitting: the worker loads it by id and,
        # presumably running on its own connection, could not see an
        # uncommitted row.
        with db.atomic():
            task_exec_record = TaskExecRecord()
            task_exec_record.task_id = resource_id
            task_exec_record.status = ExecStatus.CREATED
            task_exec_record.save()

        try:
            bgtask_id = submit(self.execute, resource_id, task_exec_record.id)
        except Exception:
            # The run will never happen; don't leave a dangling CREATED record.
            task_exec_record.delete_instance()
            raise
        if not bgtask_id:
            task_exec_record.delete_instance()
            raise falcon.HTTPInternalServerError(title="task queue has been full")

    # ---------------------------------------------------------------- execution

    def execute(self, task_id, task_exec_record_id):
        """Run all steps of a task in ``order``, recording progress in the DB.

        Submitted to the background queue by ``on_get_execute``. Stops at the
        first failing step and marks the run FAIL; marks it SUCCESS only when
        every step succeeded (a task with no steps succeeds).
        """
        # 1. Mark the run as running.
        task_exec_record = TaskExecRecord.get_by_id(task_exec_record_id)
        task_exec_record.status = ExecStatus.RUNNING
        task_exec_record.start_time = datetime.datetime.now()
        task_exec_record.save()

        # 2. Execute the steps sorted by ``order``; stop the task if any step fails.
        final_status = ExecStatus.SUCCESS
        steps = TaskStep.select().where(TaskStep.task_id == task_id).order_by(TaskStep.order)
        for step in steps:
            if not self._run_step(task_id, task_exec_record.id, step):
                final_status = ExecStatus.FAIL
                break

        task_exec_record.status = final_status
        task_exec_record.end_time = datetime.datetime.now()
        task_exec_record.save()

    @staticmethod
    def _run_step(task_id, task_exec_record_id, step):
        """Run one step on its SSH server, record the outcome, return True on success."""
        step_exec_record = StepExecRecord()
        step_exec_record.task_id = task_id
        step_exec_record.task_exec_record_id = task_exec_record_id
        step_exec_record.task_step_id = step.id
        step_exec_record.status = ExecStatus.RUNNING
        step_exec_record.start_time = datetime.datetime.now()
        step_exec_record.save()

        try:
            # 3. Write the action content to an in-memory file, upload it to the
            #    target machine and run it there with the action's interpreter.
            ssh_server = SSHServer.get_by_id(step.ssh_server_id)
            # NOTE(review): the connection is never explicitly closed here;
            # check whether SSHClient needs closing.
            ssh_client = SSHClient.from_param(ssh_server.ipaddress, ssh_server.port,
                                              ssh_server.user, ssh_server.password)

            action = Action.get_by_id(step.action_id)
            remote_temp = f"/tmp/{uuid.uuid4().hex}"
            with io.StringIO() as temp:
                temp.write(action.content)
                temp.seek(0)
                ssh_client.upload(temp, remote_temp)
            _, stdout, stderr = ssh_client.exec_command(" ".join([action.exec, remote_temp]))

            out_text = stdout.read().strip().decode("utf-8", errors="replace")
            # Test the *contents* of stderr: the stream object itself is always
            # truthy, which previously made every step fail.
            err_text = stderr.read().strip().decode("utf-8", errors="replace") if stderr else ""
        except Exception as e:
            step_exec_record.status = ExecStatus.FAIL
            step_exec_record.error_message = str(e)
            step_exec_record.end_time = datetime.datetime.now()
            step_exec_record.save()
            return False

        # NOTE(review): any stderr output counts as failure, so commands that
        # only log warnings to stderr are reported as failed; an exit-status
        # check would be more precise if SSHClient exposes one.
        succeeded = not err_text
        if succeeded:
            step_exec_record.status = ExecStatus.SUCCESS
            step_exec_record.return_code = 0
            step_exec_record.return_text = out_text
        else:
            step_exec_record.status = ExecStatus.FAIL
            step_exec_record.return_code = -1
            step_exec_record.return_text = err_text
        step_exec_record.end_time = datetime.datetime.now()
        step_exec_record.save()
        return succeeded
